from datasets import load_dataset
from distilabel.models import vLLM
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import TextGeneration
from distilabel.steps import (KeepColumns, FormatTextGenerationSFT)
import shutil
import os
import pandas as pd
#import math
#from tqdm import tqdm


# Remove any cached state left behind by an earlier run of this pipeline so
# the run below does not pick it up.
pipeline_cache = "/root/.cache/distilabel/pipelines/distill-qwen-32b-r1-socialqa"
if os.path.exists(pipeline_cache):
    shutil.rmtree(pipeline_cache)


# Prompt rendered once per record. The ``{{ ... }}`` placeholders are filled
# from the record columns named ``context``, ``question`` and ``containers``
# (the latter is added by ``add_combined_column``).
prompt_template = """\
You will be given a context, a question, and the corresponding choices. Please reason step by step to answer this question, and put your final answer within \\boxed{}:

Story: {{ context }}
Question: {{question}}
Choices: {{containers}}
"""

# Load the first 2000 training records and drop the ``charmap`` column, which
# nothing further down in this script reads.
dataset = load_dataset(
    "json",
    data_files=".../SocialIqa/socialIWa_v1.4_trn_wDims.json",
    split="train[0:2000]",
).remove_columns(["charmap"])


def add_combined_column(dataset):
    """Add the derived ``containers`` and ``entire_instruction`` fields.

    Args:
        dataset: Any object exposing ``map(fn)``; ``fn`` is applied to records
            that carry the string fields ``answerA``, ``answerB``,
            ``answerC``, ``context`` and ``question``.

    Returns:
        Whatever ``dataset.map`` returns for the per-record function below.
    """

    def _attach_choice_text(example):
        # Read all three answers up front, then label them A/B/C and join
        # them with single spaces into one "choices" string.
        first, second, third = (
            example[key] for key in ("answerA", "answerB", "answerC")
        )
        choices = " ".join(("A: " + first, "B: " + second, "C: " + third))

        example["containers"] = choices
        # Compact single-line rendering of the whole item.
        example["entire_instruction"] = (
            f"Context: {example['context']} "
            f"Question: {example['question']} "
            f"Choices: {choices}"
        )
        return example

    return dataset.map(_attach_choice_text)

# Attach the derived "containers" / "entire_instruction" columns, then print
# the dataset summary and its first record as a quick sanity check.
dataset = add_combined_column(dataset)
print(dataset)
print(dataset[0])


# One location is used for both the model weights and the tokenizer.
model_id = ".../distill32B"

with Pipeline(
    name="distill-qwen-32b-r1-socialqa",
    description="A pipeline to generate data from a distilled r1 model",
) as pipeline:
    # Options handed to the vLLM wrapper as ``extra_kwargs``.
    engine_options = {
        "tensor_parallel_size": 1,
        "max_model_len": 8192,
    }
    # Options handed to the vLLM wrapper as ``generation_kwargs``.
    # NOTE(review): max_new_tokens equals max_model_len, so the prompt and the
    # completion cannot both fit in the full budget - confirm this is intended.
    sampling_options = {
        "temperature": 0.7,
        "max_new_tokens": 8192,
    }
    llm = vLLM(
        model=model_id,
        tokenizer=model_id,
        extra_kwargs=engine_options,
        generation_kwargs=sampling_options,
    )

    # One generation per record; the prompt is rendered from the template
    # using the three listed columns.
    text_generation = TextGeneration(
        llm=llm,
        template=prompt_template,
        num_generations=1,
        input_batch_size=4,
        columns=["context", "question", "containers"],
    )

    # The formatting step takes its "instruction" input from the
    # "entire_instruction" column built by add_combined_column.
    format_sft = FormatTextGenerationSFT(
        input_mappings={"instruction": "entire_instruction"},
    )

    # Wire generation output into the SFT formatting step.
    text_generation.connect(format_sft)
    



if __name__ == "__main__":
    # Destination directory for the generated distiset; defined once so the
    # save and the reload below cannot drift apart.
    output_dir = ".../SFTData/SocialIqa_May_2_test"

    # Run the pipeline over the prepared dataset and show a summary plus the
    # first generated training record.
    distiset = pipeline.run(dataset=dataset)
    print(distiset)
    print(distiset['default']['train'][0])

    distiset.save_to_disk(output_dir)

    # Round-trip check. Per the distilabel API, ``load_from_disk`` is a
    # class-level constructor that RETURNS the loaded distiset; it does not
    # modify the object it is called on. The original code discarded that
    # return value and re-printed the in-memory object, so the reload was
    # never actually inspected.
    reloaded_distiset = type(distiset).load_from_disk(output_dir)
    print(reloaded_distiset)
    